from __future__ import division, print_function, absolute_import

import numpy as np
import pickle
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *

#
# This section runs an example numerical reconstruction
# to report in the supplement

# Band limit handed to generate_blr_probe (presumably the probe's maximum
# spatial frequency in pixels -- confirm against RIP_tools)
probe_maxk = 128

# Switch this to False when no GPU is available
GPU = True

# Side length of the square test image: 0.8 * probe_maxk, rounded down
# to an even number of pixels. The original note labels this setup as a
# reconstruction with R = 0.5.
# NOTE(review): the factor used here is 0.8 -- confirm how R is defined
resolution = 2 * int(0.8 * probe_maxk / 2)

# Side length of the square probe array: twice the combined extent of
# the image and the probe band limit
field_size = 2 * int(np.ceil(probe_maxk + resolution))
# Focal spot size of the BLR probe: 40% of the field size, floored
probe_fov_radius = 4 * field_size // 10

# Fraction of the image width trimmed from every edge before the error
# comparison, so only the central half along each axis is compared
error_check_padding = 0.25


# Build the ideal BLR probe on a square field_size x field_size grid
probe_shape = [field_size, field_size]
nearfield = generate_blr_probe(probe_shape, probe_fov_radius, probe_maxk)

# Draw a random complex test image: independent standard-normal real and
# imaginary parts (the real part is drawn first), then convert it to the
# float32 torch representation used by the simulation helpers
original_image = (np.random.randn(resolution, resolution)
                  + 1j * np.random.randn(resolution, resolution))
original_image = complex_to_torch(original_image).to(dtype=t.float32)

# Upsample the probe in real space so that the probe-object
# interaction is not aliased
upsampled_probe = expand_probe(nearfield, original_image)

# Run the forward simulation: probe-object interaction, propagation,
# and finally the measurement, which is given the shape of the probe
intensities = measure(
    propagate(interact(original_image, upsampled_probe)),
    nearfield.shape)

# Rescale the intensities in place to a consistent total of
# 10 * probe_maxk**2
total_intensity = t.sum(intensities)
intensities /= total_intensity
intensities *= 10 * probe_maxk ** 2


# Per-attempt records: the loss record returned by reconstruct, the image
# error from compare_images, and the retrieved complex image
losses = []
errors = []
retrievals = []

# The image error is only evaluated on the interior of the image, with
# error_check_padding * resolution pixels trimmed from every edge
padding = int(resolution * error_check_padding)
central_region = np.s_[padding:-padding, padding:-padding]

for attempt in range(50):
    print('Running attempt',attempt)

    # Run one reconstruction (per the original notes, each call starts
    # from a randomized initialization)
    retrieval, loss = reconstruct(
        intensities, nearfield, resolution, 0.4, 1000,
        GPU=GPU, schedule=True)
    ret = torch_to_complex(retrieval)
    im = torch_to_complex(original_image)

    # Error between the retrieved and true images within the central
    # region (described as an RMS error in the original notes); the two
    # extra return values are not needed here
    error, _, _ = compare_images(ret, im, central_region)
    print(loss[-1])

    # Record this attempt's loss record, image error and retrieval
    losses.append(loss)
    errors.append(error)
    retrievals.append(ret)
    


# Bundle the probe, the true image, and the per-attempt losses, errors and
# retrievals into one dictionary and pickle it for later analysis.
import os

output_path = 'data/example reconstruction.pickle'

# open() does not create missing directories, and a failure at this point
# would throw away every reconstruction above, so make sure the output
# directory exists first. (isdir + makedirs is used instead of
# exist_ok=True because this file still carries Python 2 __future__ imports.)
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)

results = {
    'probe': torch_to_complex(nearfield),
    'image': torch_to_complex(original_image),
    'losses': losses,
    'errors': errors,
    'retrievals': retrievals,
}
with open(output_path, 'wb') as f:
    pickle.dump(results, f)
